import numpy as np

# Sample a cubic grid of 3-D query points spanning [-scale, scale] on every axis.
scale = 0.2  # half-width of the sampled cube
N = 256  # samples per axis; the grid holds N**3 points
t = np.linspace(-scale, scale, N)  # 1-D axis coordinates, both endpoints included
print('t', t.shape)

# With three inputs, np.meshgrid returns three (N, N, N) float64 arrays. It
# defaults to indexing='xy', so query_pts[i, j, k] == (t[j], t[i], t[k]): the
# first two grid axes are swapped relative to indexing='ij'.
# NOTE(review): if per-point results are later reshaped back onto this grid
# and treated as an (x, y, z)-ordered volume, indexing='ij' may be what was
# intended. Left as-is here because changing it reorders every point.
grid = np.meshgrid(t, t, t)
for i, item in enumerate(grid):
    print(i, item.shape)

# Fill a preallocated float32 array channel by channel instead of calling
# np.stack(grid, -1).astype(np.float32). That would first materialize a full
# N**3 x 3 float64 copy (~400 MB at N=256) only to downcast it. Assigning
# float64 into a float32 array rounds exactly like astype, so the values match.
query_pts = np.empty(grid[0].shape + (len(grid),), dtype=np.float32)
for axis, coord in enumerate(grid):
    query_pts[..., axis] = coord
print('points shape:', query_pts.shape)
# (N, N, N, 3); presumably used later to reshape per-point results back onto
# the grid -- verify against downstream code.
sh = query_pts.shape
# (N**3, 3) view over the C-contiguous query_pts: one row per point.
flat = query_pts.reshape([-1, 3])
print('flat shape:', flat.shape)
